module PreferenceModels
#%%
using DataFrames
using DataFramesMeta
using Optim
using Distributed
using StatsBase:mean
using Parameters
using LineSearches
using LinearAlgebra: diag
using ForwardDiff
using BlackBoxOptim
using NLopt
using JuMP
using Ipopt





export FSR, utility

# Alias for a list of parameter names (field symbols).
# Declared `const` so that uses inside functions (e.g. `GlobParams()`) are type-stable.
const GlobParams = Vector{Symbol}

# Logistic function 1 / (1 + exp(-x)) with the result clipped:
# anything above 0.999999 returns 0.999999, anything below 0.0000001 returns 0.0000001.
function sigmoid(x::Real)
    unit = one(x)
    raw = unit / (unit + exp(-x))
    raw > 0.999999 && return 0.999999
    raw < 0.0000001 && return 0.0000001
    return raw
end

# Log-odds log(p / (1 - p)) of `x`, after clamping `x` to [0.0000001, 0.99999999]
# so the result is always finite.
# The lower threshold and the lower clamp value are the same number; previously the
# threshold was ten times larger than the clamp value, so inputs between the two were
# flattened to the clamp value.
function logit(x::Real)
    lo, hi = 0.0000001, 0.99999999
    p = x < lo ? lo : (x > hi ? hi : x)
    return log(p / (one(p) - p))
end


# Softmax over all entries of `xs`; the maximum is subtracted before exponentiating.
function softmax(xs::AbstractArray)
    shifted = exp.(xs .- maximum(xs))
    return shifted ./ sum(shifted)
end

# Two-value convenience form: softmax of the vector `[a, b]`.
softmax(a, b) = softmax([a, b])

# Element-wise two-option softmax.
# For each position, returns the probability of the first option (`px`) and of the
# second option (`py`); the element-wise maximum is subtracted before exponentiating.
function vec_softmax(u_X, u_Y)
    shift = max.(u_X, u_Y)
    ex = exp.(u_X .- shift)
    ey = exp.(u_Y .- shift)
    total = ex .+ ey
    return (px = ex ./ total, py = ey ./ total)
end

# Maps values to unconstrained scores as log(x) + 2, element-wise.
# Applying `softmax` to the result of this on a vector that sums to one returns that
# vector again, because softmax ignores the constant offset.
inv_softmax(xs::AbstractArray) = log.(xs) .+ 2

# Abstract supertype of the preference models in this module (`FSR`, `Decency`).
# Generic fallbacks such as `get_fields`, `get_params` and `len` dispatch on it.
abstract type Preference end


# Generic fallback: every field name of the concrete type of `σ`, as a tuple of symbols.
get_fields(σ::Preference) = fieldnames(typeof(σ))

# Generic fallback: the values of all fields of `σ`, collected into a vector
# in field-declaration order.
function get_params(σ::Preference)
    names = fieldnames(typeof(σ))
    return getfield.([σ], names)
end

# Generic fallback: the vector handed to an optimiser is simply `get_params(σ)`.
get_params_for_opt(σ::Preference) = get_params(σ)

# Generic fallback: the optimiser vector `x` is returned unchanged (`σ` is not used).
params_from_opt(x::AbstractArray, σ::Preference) = x

# Probability assigned to the observed outcome: `preds` where y == 1 and `1 - preds`
# where y == 0, plus the smallest positive Float64 so the value is never exactly zero.
function calc_probs(preds, y)
    return preds .* y .+ (1 .- preds) .* (1 .- y) .+ eps(0.0)
end

# Element-wise log of `calc_probs`.
function ll(preds, y)
    return log.(calc_probs(preds, y))
end

# 1 where the observed outcome received probability above 0.5, otherwise 0.
function accu(preds, y)
    return (calc_probs(preds, y) .> 0.5) * 1
end

# Mean of the element-wise log of `calc_probs`.
function mean_ll(preds, y)
    return mean(log.(calc_probs(preds, y)))
end


# Number of entries in the optimiser vector of `σ`.
function len(σ::Preference)
    return length(get_params_for_opt(σ))
end

# No-op: ignores both arguments and returns `nothing`.
function copy_funs!(σ1, σ2)
    return nothing
end

# Number of distinct elements in `x`.
function nunique(x)
    distinct = unique(x)
    return length(distinct)
end


#####################################################################
#%% Fehr-Schmidt preferences with reciprocity
# Preference type with three parameters.
# `utility` puts weight `α*s + β*r` on the other player's payoff and the remainder on
# the own payoff; `predict` multiplies utilities by `λ` before the two-option softmax.
# The struct is mutable because `set_from_opt!` overwrites the fields in place.
@with_kw mutable struct FSR <: Preference
    α::Real  # coefficient on `s` in `utility`
    β::Real  # coefficient on `r` in `utility`
    λ::Real  # scale applied to utilities in `predict`
end

# Default instance: α = 0.5, β = 0.5, λ = 0.01.
FSR() = FSR(0.5, 0.5, 0.01)


# All field names of the concrete type of `σ` (for `FSR`: α, β, λ).
get_fields(σ::FSR) = fieldnames(typeof(σ))

# Values of the fields named in `ps` (default: every `FSR` field), as a vector.
function get_params(σ::FSR; ps = fieldnames(FSR))
    holder = [σ]
    return getfield.(holder, ps)
end


# Builds a NamedTuple of parameter values: the names in `ps` take their values from `x`
# (in order); every remaining `FSR` field keeps its current value from `σ`.
function params_from_opt(x::AbstractArray, σ::FSR; ps=fieldnames(FSR))
    kept = filter(name -> name ∉ ps, fieldnames(FSR))
    names = [ps..., kept...]
    vals = [x..., getfield.([σ], kept)...]
    return (; zip(names, vals)...)
end

# Scalar utility: weight (α*s + β*r) on `π_B` and one minus that weight on `π_A`.
# `q` and `v` are accepted but not used.
utility(σ::FSR, π_A, π_B, s, r, q, v) =
    (1 - σ.α * s - σ.β * r) * π_A + (σ.α * s + σ.β * r) * π_B

# Element-wise utility over vectors of observations: weight (α*s + β*r) on `π_B`
# and one minus that weight on `π_A`.
function utility(σ::FSR, π_A::Vector, π_B::Vector, s::Vector, r::Vector)
    return @. (1 - σ.α * s - σ.β * r) * π_A + (σ.α * s + σ.β * r) * π_B
end

# Probability of choosing option x for every row of `X`.
# Reads columns self_x, other_x, s_x, r_x and self_y, other_y, s_y, r_y,
# scales both utilities by `σ.λ` and applies the two-option softmax.
function predict(X::DataFrame, σ::FSR)
    scaled_x = σ.λ .* utility(σ, X.self_x, X.other_x, X.s_x, X.r_x)
    scaled_y = σ.λ .* utility(σ, X.self_y, X.other_y, X.s_y, X.r_y)
    return vec_softmax(scaled_x, scaled_y)[:px]
end


# Optimiser vector of an `FSR`: all of its parameter values, untransformed.
get_params_for_opt(σ::FSR) = get_params(σ)

# Writes the values of `x` into the fields of `σ` named by `ps`, pairing them in order
# (pairing stops at the shorter of the two). Returns the mutated `σ`.
function set_from_opt!(x::AbstractArray, σ::FSR; ps=fieldnames(FSR))
    foreach(zip(ps, x)) do (name, value)
        setfield!(σ, name, value)
    end
    return σ
end


# Lower bounds for (α, β, λ), in field order.
function lower(σ::FSR)
    return [-1.1, -1.1, 0.001]
end

# Lower bounds for just the parameters named in `ps`, in the order given.
function lower(σ::FSR, ps)
    bounds = (; zip(get_fields(σ), lower(σ))...)
    return Float64.([bounds[name] for name in ps])
end

# Upper bounds for (α, β, λ), in field order.
function upper(σ::FSR)
    return [1.1, 1.1, 1.1]
end

# Upper bounds for just the parameters named in `ps`, in the order given.
function upper(σ::FSR, ps)
    bounds = (; zip(get_fields(σ), upper(σ))...)
    return Float64.([bounds[name] for name in ps])
end


#####################################################################
#%% Decency with fixed type

 # Preference type with keyword defaults. Numeric fields are the model parameters;
 # `used`, `lower` and `upper` are per-instance tables consulted by `get_fields`,
 # `lower` and `upper`. The "= ξ / = λ / = κ" remarks are the original author's
 # notation notes (presumably the symbols used in a write-up — TODO confirm).
 @with_kw mutable struct Decency <: Preference
    δ::Real = 0.5  # weight on blame in `calc_utility`
    # π::Real = 0.5 # n  (disabled field, kept from the original)
    λ::Real = 0.01 # = ξ; scale applied to utilities before the softmax in `predict`
    β::Real = 0.5 # weight on communal-optimal vs. selfish values in `calc_hyp_e`
    α::Real = 0.5 # weight on |own - other| in `communal_values`
    γ::Real = 0.5 # = λ; weight on communal blame in `calc_blame_comb`
    κ::Real = 0.5 # = κ; weight on max(0, other - own) in `calc_utility`
    communal::Bool = true  # multiplies the communal blame term
    harm::Bool = true  # multiplies the harm blame term
    e_vals_from_x::Bool = false  # if true, `predict` reads entitlements from X.e_A / X.e_B
    # Which fields `get_fields` reports as parameters.
    used::Dict{Any,Any}=Dict(:δ=> true, :λ=> true, :β=> true, :α=> true, :γ=> true, :κ => true, :communal => false, :harm => false, :used => false, :lower => false, :upper=> false, :e_vals_from_x=>false)
    # Lower bound per parameter name.
    lower::Dict{Any,Any}=Dict(:δ=> -2., :λ=> 0.0001, :β=> 0.0, :α=>0.0, :γ=>0.0, :κ=>0.0)
    # Upper bound per parameter name.
    upper::Dict{Any,Any}=Dict(:δ=> 2., :λ=> 0.3, :β=> 1., :α=>10., :γ=>2., :κ=>3.)
end


# Applies overrides to `σ` in place and returns it:
# `change_used`, `change_lower`, `change_upper` are merged into the corresponding tables
# of `σ`; `change_vals` maps field names to new field values.
function ParamsDiff(σ::Decency;change_used::Dict=Dict(), change_lower::Dict=Dict(), change_upper::Dict=Dict(), change_vals::Dict=Dict())
    merge!(σ.used, change_used)
    merge!(σ.lower, change_lower)
    merge!(σ.upper, change_upper)
    foreach(pairs(change_vals)) do (name, value)
        setfield!(σ, name, value)
    end
    return σ
end


# Field names of `Decency` that are flagged `true` in `σ.used`, as a tuple.
function get_fields(σ::Decency)
    return filter(fieldnames(Decency)) do name
        σ.used[name]
    end
end

# Values of the fields named in `ps`, as a vector.
# With `ps = nothing` the names come from `get_fields(σ)`.
function get_params(σ::Decency; ps = nothing)
    names = isnothing(ps) ? get_fields(σ) : ps
    return getfield.([σ], names)
end


# Builds a NamedTuple of parameter values: the names in `ps` take their values from `x`
# (in order); every other active field of `σ` keeps its current value.
# `ps` defaults to the active fields of `σ` itself, so the default agrees with
# `get_params_for_opt(σ)` even when `σ.used` was changed (e.g. through `ParamsDiff`).
function params_from_opt(x::AbstractArray, σ::Decency; ps=get_fields(σ))
    kept = filter(name -> name ∉ ps, get_fields(σ))
    names = [ps..., kept...]
    vals = [x..., getfield.([σ], kept)...]
    return (; zip(names, vals)...)
end


# Optimiser vector of a `Decency`: the values of its active fields (`get_fields(σ)`).
get_params_for_opt(σ::Decency) = [getfield(σ, name) for name in get_fields(σ)]

# Writes the values of `x` into the fields of `σ` named by `ps`, pairing them in order
# (pairing stops at the shorter of the two). Returns the mutated `σ`.
# `ps` defaults to the active fields of `σ` itself; taking them from a fresh `Decency()`
# would pair values with the wrong fields whenever `σ.used` differs from the default.
function set_from_opt!(x::AbstractArray, σ::Decency; ps=get_fields(σ))
    for (field, val) in zip(ps, x)
        setfield!(σ, field, val)
    end
    return σ
end


# Lower bounds from `σ.lower` for the parameters named in `ps`, in the order given.
lower(σ::Decency, ps) = Float64.([σ.lower[name] for name in ps])

# Lower bounds for the active fields of `σ`.
lower(σ::Decency) = lower(σ, get_fields(σ))

# Upper bounds from `σ.upper` for the parameters named in `ps`, in the order given.
upper(σ::Decency, ps) = Float64.([σ.upper[name] for name in ps])

# Upper bounds for the active fields of `σ`.
upper(σ::Decency) = upper(σ, get_fields(σ))

# Number of active parameters of `σ`.
function len(σ::Decency)
    return length(get_fields(σ))
end


# For each row: the larger of the two first-player payoffs and the smaller of the two
# second-player payoffs, each multiplied element-wise by `dg`.
# `outcomes[:X]` / `outcomes[:Y]` are indexed as (first-player, second-player).
function maxmin(outcomes, dg::Vector)
    opt_x, opt_y = outcomes[:X], outcomes[:Y]
    best_A = max.(opt_x[1], opt_y[1])
    worst_B = min.(opt_x[2], opt_y[2])
    return (; A = dg .* best_A, B = dg .* worst_B)
end

# Payoffs under the option that maximises the first player's payoff, row by row.
# Returns `A` (first player's payoff) and `B` (second player's payoff under that same
# option). Ties go to option X, because `argmax` returns the first maximal index.
function selfish(outcomes)
    own = hcat(outcomes[:X][1], outcomes[:Y][1])
    other = hcat(outcomes[:X][2], outcomes[:Y][2])
    pick = argmax(own, dims=2)
    return (; A = vec(own[pick]), B = vec(other[pick]))
end

# Communal value of each option: sum of both payoffs minus `α` times their absolute
# difference. `σ` is only used for dispatch; the weight is the explicit `α` argument.
function communal_values(σ::Decency, outcomes, α)
    score(own, other) = (own .+ other) - α .* abs.(own .- other)
    opt_x, opt_y = outcomes[:X], outcomes[:Y]
    return (X = score(opt_x[1], opt_x[2]), Y = score(opt_y[1], opt_y[2]))
end

# Payoffs (A, B) under the option whose communal value equals `c_max`, row by row.
# Each option's payoffs are multiplied by a 0/1 mask that is 1 where that option
# attains `c_max`. The mask is applied to both options for both players; previously
# the Y term of `B` had no mask, so Y's payoff was always added.
# NOTE(review): where both options tie (both masks are 1) the payoffs of X and Y are
# summed — confirm whether a tie-break is intended.
function calc_c_opt_vals(outcomes, c_max, c_vals)
    pick_X = c_max .== c_vals.X
    pick_Y = c_max .== c_vals.Y
    A_opt = outcomes.X[1] .* pick_X + outcomes.Y[1] .* pick_Y
    B_opt = outcomes.X[2] .* pick_X + outcomes.Y[2] .* pick_Y
    return (; A = A_opt, B = B_opt)
end

# Convex combination of two payoff pairs: weight `β` on `c_vals` and `1 - β` on
# `selfish_vals`, separately for the `A` and `B` components.
function calc_hyp_e(c_vals, selfish_vals, β)
    blend(c, s) = β .* c .+ (1 - β) .* s
    return (; A = blend(c_vals.A, selfish_vals.A), B = blend(c_vals.B, selfish_vals.B))
end

# Entitlement pair (A, B) per row.
# 1. Blend the communal-optimal payoffs with the selfish payoffs using `σ.β`.
# 2. For each row, measure how far each available option's B payoff falls short of
#    the blended B value (`shortfall`, never negative).
# 3. If some option has zero shortfall, the blended pair is the entitlement;
#    otherwise the payoffs of the option with the smallest shortfall are used.
# `outcomes_M.A` / `outcomes_M.B` are matrices with one column per option.
# `dg` is accepted but not used.
function calc_entitlements(outcomes, outcomes_M, σ::Decency, dg::Vector)
    selfish_vals = selfish(outcomes)
    c_vals = communal_values(σ, outcomes, σ.α)
    c_opt_vals = calc_c_opt_vals(outcomes, max.(c_vals.X, c_vals.Y), c_vals)
    hyp_e = calc_hyp_e(c_opt_vals, selfish_vals, σ.β)

    shortfall = max.(hyp_e.B .- outcomes_M.B, 0)
    smallest = minimum(shortfall, dims=2)[:]
    closest = argmin(shortfall, dims=2)

    unmet = smallest .!= 0
    met = smallest .== 0
    A_e = outcomes_M.A[closest][:] .* unmet .+ hyp_e.A .* met
    B_e = outcomes_M.B[closest][:] .* unmet .+ hyp_e.B .* met
    return (; A = A_e, B = B_e)
end


# Total harm per row: how far `A` falls below `e_A` plus how far `B` falls below `e_B`
# (each shortfall floored at zero).
function tot_harm(A::Vector, B::Vector, e_A::Vector, e_B::Vector)
    short_A = max.(0, e_A .- A)
    short_B = max.(0, e_B .- B)
    return short_A + short_B
end

# Harm per player and row: shortfall of `A` below `e_A` and of `B` below `e_B`,
# each floored at zero.
function ind_harms(A::Vector, B::Vector, e_A::Vector, e_B::Vector)
    short_A = max.(0, e_A .- A)
    short_B = max.(0, e_B .- B)
    return (; A = short_A, B = short_B)
end

# Total harm of each option (X and Y) relative to the entitlements `e_vals`.
function calc_harms(outcomes, e_vals)
    harm_of(opt) = tot_harm(opt[1], opt[2], e_vals[:A], e_vals[:B])
    return (X = harm_of(outcomes[:X]), Y = harm_of(outcomes[:Y]))
end

# Per-player harm of each option (X and Y) relative to the entitlements `e_vals`.
function calc_ind_harms(outcomes, e_vals)
    harm_of(opt) = ind_harms(opt[1], opt[2], e_vals[:A], e_vals[:B])
    return (X = harm_of(outcomes[:X]), Y = harm_of(outcomes[:Y]))
end


# An option is permitted where its harm to player B is no larger than the other
# option's harm to player B (so both are permitted on ties).
# `harms` is kept in the signature for callers but is not used: the result depends only
# on the per-player harms. Two element-wise minima that were computed and then
# discarded have been removed.
function calc_permitted(harms, ind_harms)
    harm_B_x = ind_harms[:X].B
    harm_B_y = ind_harms[:Y].B
    return (X = harm_B_x .<= harm_B_y, Y = harm_B_y .<= harm_B_x)
end



# An option is communally permitted where its communal value is at least the other
# option's (both are permitted on ties). `permitted` is accepted but not used.
function calc_permitted_c(c_values, permitted)
    val_x = c_values[:X]
    val_y = c_values[:Y]
    return (X = val_x .>= val_y, Y = val_y .>= val_x)
end


# Harm-based blame per option: for a non-permitted option, the amount by which its harm
# to player B exceeds the other option's harm to player B; zero for permitted options.
# The whole term is multiplied by the `σ.harm` switch.
function calc_blame_harm(σ::Decency, outcomes, permitted , e_vals)
    harm = calc_ind_harms(outcomes, e_vals)
    excess_x = max.(0, harm[:X][:B] .- harm[:Y][:B])
    excess_y = max.(0, harm[:Y][:B] .- harm[:X][:B])
    blame_x = max.(0, σ.harm .* (.!permitted[:X]) .* excess_x)
    blame_y = max.(0, σ.harm .* (.!permitted[:Y]) .* excess_y)
    return (X = blame_x, Y = blame_y)
end

# Communal blame per option: for an option that is not communally permitted, the amount
# by which player B's payoff under the communal-optimal option exceeds B's payoff under
# this option (floored at zero). Multiplied by the `σ.communal` switch.
function calc_blame_communal(σ::Decency, outcomes, c_vals, permitted_c)
    best = max.(c_vals.X, c_vals.Y)
    target_B = calc_c_opt_vals(outcomes, best, c_vals).B
    blame(opt) = σ.communal .* (.!permitted_c[opt]) .* max.(0, target_B .- outcomes[opt][2])
    return (X = blame(:X), Y = blame(:Y))
end

# Combined blame per option: harm-based blame plus `γ` times communal blame.
function calc_blame_comb(σ::Decency, outcomes, permitted, c_vals, permitted_c, e_vals, γ)
    harm_part = calc_blame_harm(σ, outcomes, permitted, e_vals)
    communal_part = calc_blame_communal(σ, outcomes, c_vals, permitted_c)
    total(opt) = harm_part[opt] .+ γ .* communal_part[opt]
    return (X = total(:X), Y = total(:Y))
end

# Utility of each option: own payoff, minus `σ.δ` times the option's blame, minus
# `σ.κ` times the amount by which the other player's payoff exceeds the own payoff.
function calc_utility(σ::Decency, outcomes, blame)
    value(own, other, b) = own .- σ.δ .* b .- σ.κ .* max.(0, other .- own)
    opt_x, opt_y = outcomes[:X], outcomes[:Y]
    return (X = value(opt_x[1], opt_x[2], blame[:X]),
            Y = value(opt_y[1], opt_y[2], blame[:Y]))
end

# Logistic choice probabilities from a utility pair: P(X) = 1 / (1 + exp(λ (u_Y - u_X)))
# and P(Y) = 1 - P(X), with λ = `σ.λ`.
function calc_choice_probs(σ::Decency, utility)
    gap = utility[:Y] .- utility[:X]
    p_x = 1 ./ (1 .+ exp.(σ.λ .* gap))
    return (X = p_x, Y = 1 .- p_x)
end


# Probability of choosing option X for every row.
# Pipeline: entitlements -> harms -> permitted options -> communal values -> blame ->
# utilities -> two-option softmax on utilities scaled by `σ.λ`.
# Entitlements are read from `X.e_A` / `X.e_B` when `σ.e_vals_from_x` is set, otherwise
# computed by `calc_entitlements`.
function predict(X, outcomes, outcomes_M, σ::Decency)
    e_vals = if σ.e_vals_from_x
        (; A = X.e_A, B = X.e_B)
    else
        calc_entitlements(outcomes, outcomes_M, σ, X.dg)
    end

    total_harm = calc_harms(outcomes, e_vals)
    player_harm = calc_ind_harms(outcomes, e_vals)
    permitted = calc_permitted(total_harm, player_harm)
    c_vals = communal_values(σ, outcomes, σ.α)
    permitted_c = calc_permitted_c(c_vals, permitted)
    blame = calc_blame_comb(σ, outcomes, permitted, c_vals, permitted_c, e_vals, σ.γ)
    utils = calc_utility(σ, outcomes, blame)

    px = vec_softmax(σ.λ .* utils[:X], σ.λ .* utils[:Y])[:px]
    # If any probability is within 10^(-30) of 0 or 1, redo the softmax with the
    # scaled utilities divided by 100.
    saturated = any(abs.(px .- 1) .< 10^(-30)) | any(abs.(px) .< 10^(-30))
    if saturated
        px = vec_softmax(σ.λ .* utils[:X] ./ 100, σ.λ .* utils[:Y] ./ 100)[:px]
    end
    return px
end

# DataFrame entry point: builds the per-option payoff tuples and the payoff matrices
# (one column per option) from columns self_x, other_x, self_y, other_y, then delegates
# to the four-argument `predict`.
function predict(X::DataFrame, σ::Decency)
    outcomes = (X = (X.self_x, X.other_x), Y = (X.self_y, X.other_y))
    self_M = Matrix(X[:, [:self_x, :self_y]])
    other_M = Matrix(X[:, [:other_x, :other_y]])
    return predict(X, outcomes, (; A = self_M, B = other_M), σ)
end


###################################################################
#%% Mixture-Model

# Mixture of several preference types with mixing weights.
mutable struct MixModel
    types::Vector  # one preference object per mixture component
    shares::Vector{Real}  # mixing weight of each component
    K::Int64  # number of components (the constructors below pass length(shares))
    glob_params::GlobParams  # names of parameters written to every component
    local_params::Vector{GlobParams}  # per component: its field names not in glob_params
end



# Builds a mixture with random shares (uniform draws normalised to sum to one).
# Each component's local parameters are its `get_fields` names not in `glob_params`.
function MixModel(types::Vector, glob_params::GlobParams)
    draws = rand(length(types))
    shares = draws ./ sum(draws)
    local_params = map(types) do σ
        Symbol[name for name in get_fields(σ) if name ∉ glob_params]
    end
    return MixModel(types, shares, length(shares), glob_params, local_params)
end

# Builds a mixture with the given `shares`.
# Each component's local parameters are its `get_fields` names not in `glob_params`.
function MixModel(types::Vector, shares, glob_params::GlobParams)
    local_params = map(types) do σ
        Symbol[name for name in get_fields(σ) if name ∉ glob_params]
    end
    return MixModel(types, shares, length(shares), glob_params, local_params)
end

# Builds a mixture with random shares and no global parameters, so every active field
# of each component is a local parameter.
function MixModel(types::Vector)
    draws = rand(length(types))
    shares = draws ./ sum(draws)
    no_globals = GlobParams()
    local_params = map(types) do σ
        Symbol[name for name in get_fields(σ) if name ∉ no_globals]
    end
    return MixModel(types, shares, length(shares), no_globals, local_params)
end

# Flat parameter vector of the mixture.
# K == 1: the parameters of the single component, once.
# K > 1: the parameters of every component in order, followed by the shares.
# The single-component case used to list the parameters twice, which made the result
# unusable as input to `set_params` (it splats the vector into the constructor).
function get_params(m::MixModel)
    if m.K == 1
        return [get_params(m.types[1])...]
    else
        return vcat([get_params(m.types[i]) for i in 1:m.K]..., m.shares)
    end
end


# Rebuilds the components of `m` from the flat vector `x` and returns `m` (mutated).
# K == 1: `x` is splatted positionally into the component's constructor.
# K > 1: consecutive chunks of `x` (one per component, sized by `get_params`) are passed
# as keywords named by `get_fields`; the last K entries of `x` become the shares.
function set_params(x, m::MixModel)
    if m.K == 1
        m.types[1] = typeof(m.types[1])(x...)
        return m
    end
    offset = 0
    for k in 1:m.K
        σ = m.types[k]
        n = length(get_params(σ))
        chunk = x[offset+1:offset+n]
        kwargs = (; zip(get_fields(σ), chunk)...)
        m.types[k] = typeof(σ)(; kwargs...)
        offset += n
    end
    m.shares = x[end-m.K+1:end]
    return m
end

# Writes the values `x` into the global parameters (`m.glob_params`) of every component.
# Returns `m`.
function set_glob_params!(x, m::MixModel)
    foreach(1:m.K) do k
        set_from_opt!(x, m.types[k]; ps=m.glob_params)
    end
    return m
end

# Optimiser vector of the mixture, as Float64.
# Layout for K == 1: [local params of component 1; global params].
# Layout for K > 1: [local params of each component; inv_softmax(shares); global params].
# Global parameter values are read from the first component.
function get_params_for_opt(m::MixModel)
    first_type = m.types[1]
    if m.K == 1
        locals = get_params(first_type; ps=m.local_params[1])
        globals = get_params(first_type; ps=m.glob_params)
        return Float64.([locals..., globals...])
    end
    locals = [get_params(m.types[i]; ps=m.local_params[i]) for i in 1:m.K]
    globals = get_params(first_type; ps=m.glob_params)
    return Float64.(vcat(locals..., inv_softmax(m.shares), globals))
end

# Inverse of `get_params_for_opt`: writes the optimiser vector `x` back into `m`.
# Each component receives its own chunk of local parameters plus the trailing global
# parameters; for K != 1 the K entries before the globals are passed through `softmax`
# to become the shares. Returns `m`.
function set_params_from_opt!(x, m::MixModel)
    n_glob = length(m.glob_params)
    offset = 0
    for k in 1:m.K
        n_local = length(m.local_params[k])
        set_from_opt!(x[offset+1:offset+n_local], m.types[k]; ps=m.local_params[k])
        set_from_opt!(x[end-n_glob+1:end], m.types[k]; ps=m.glob_params)
        offset += n_local
    end
    if m.K != 1
        m.shares = softmax(x[end-n_glob-m.K+1:end-n_glob])
    end
    return m
end




# Lower bounds matching the layout of `get_params_for_opt(m)`.
# Share slots (present only for K != 1) get -4.1; global-parameter bounds are taken
# from the first component.
function lower(m::MixModel)
    bound_table(σ) = Dict(zip(get_fields(σ), lower(σ)))
    if m.K == 1
        table = bound_table(m.types[1])
        locals = [table[k] for k in m.local_params[1]]
        globals = [table[k] for k in m.glob_params]
        return Float64.(vcat(locals, globals))
    end
    locals = map(1:m.K) do i
        table = bound_table(m.types[i])
        [table[k] for k in m.local_params[i]]
    end
    head_table = bound_table(m.types[1])
    globals = [head_table[k] for k in m.glob_params]
    return Float64.(vcat(locals..., fill(-4.1, m.K), globals))
end

# Upper bounds matching the layout of `get_params_for_opt(m)`.
# Share slots (present only for K != 1) get 4.1; global-parameter bounds are taken
# from the first component.
function upper(m::MixModel)
    bound_table(σ) = Dict(zip(get_fields(σ), upper(σ)))
    if m.K == 1
        table = bound_table(m.types[1])
        locals = [table[k] for k in m.local_params[1]]
        globals = [table[k] for k in m.glob_params]
        return Float64.(vcat(locals, globals))
    end
    locals = map(1:m.K) do i
        table = bound_table(m.types[i])
        [table[k] for k in m.local_params[i]]
    end
    head_table = bound_table(m.types[1])
    globals = [head_table[k] for k in m.glob_params]
    return Float64.(vcat(locals..., fill(4.1, m.K), globals))
end


# Names for the mixture's parameters, using every field name of each component type.
# K == 1: the field names of the single component.
# K > 1: the field names of all components in order, then share_1 ... share_K.
function param_names(m::MixModel)
    if m.K == 1
        return [fieldnames(typeof(m.types[1]))...]
    end
    labels = []
    for k in 1:m.K
        labels = [labels..., fieldnames(typeof(m.types[k]))...]
    end
    share_labels = [Symbol("share_"*string(k)) for k in 1:m.K]
    labels = [labels..., share_labels...]
    return collect(labels)
end



# Names for the mixture's active parameters (`get_fields`).
# K == 1: the plain field names. K > 1: each name is suffixed with its component index
# (e.g. α1, α2), followed by share_1 ... share_K.
function param_numbered_names(m::MixModel)
    if m.K == 1
        return [get_fields(m.types[1])...]
    end
    labels = []
    for k in 1:m.K
        numbered = [Symbol(string(s)*string(k)) for s in collect(get_fields(m.types[k]))]
        labels = [labels..., numbered...]
    end
    share_labels = [Symbol("share_"*string(k)) for k in 1:m.K]
    labels = [labels..., share_labels...]
    return collect(labels)
end


# Pairs each numbered parameter name with the corresponding value from `get_params(m)`.
# Returns a vector of `name => value`; nothing is printed. Pairing stops at the shorter
# of the two lists.
function print_params(m)
    labels = param_numbered_names(m)
    values = get_params(m)
    return [label => value for (label, value) in zip(labels, values)]
end

###################################################
#%%

# Log-likelihood of the observed choices (`X.choice_x`) under `σ`: the sum over rows of
# log(probability of the observed choice + 0.000000001).
function perf(X, σ::Preference)
    preds = predict(X, σ)
    row_ll = log.(calc_probs(preds, X[!,:choice_x]) .+ 0.000000001)
    return sum(row_ll)
end

# Share of rows in which `σ` gives the observed choice (`X.choice_x`) a probability
# above 0.5.
# The result of `accu` is stored under a different name: assigning to a local called
# `accu` makes that name local to the whole function, so the call `accu(...)` would
# fail with an UndefVarError.
function accuracy(X, σ::Preference)
    preds = predict(X, σ)
    hits = accu(preds, X[!,:choice_x])
    return mean(hits)
end

# Mixture prediction: share-weighted sum of the components' predicted probabilities.
function predict(X, model::MixModel)
    start = zeros(size(X)[1])
    return foldl(1:model.K; init=start) do acc, k
        acc .+ predict(X, model.types[k]) .* model.shares[k]
    end
end

# Mixture log-likelihood, grouping rows by `sid`.
# For each component k the row log-likelihoods are summed within each `sid`; these sums
# are exponentiated, mixed with `model.shares`, logged and added up over `sid`s.
# Mixed likelihoods that are <= 0 are replaced by `eps()` before taking logs.
# Side effect: adds/overwrites columns pred_k and ll_k in `X`.
function perf(X, model::MixModel)
    pred_cols = [Symbol("pred_"*string(k)) for k in 1:model.K]
    ll_cols = [Symbol("ll_"*string(k)) for k in 1:model.K]
    for k in 1:model.K
        X[!, pred_cols[k]] = predict(X, model.types[k])
        X[!, ll_cols[k]] = log.(calc_probs(X[!, pred_cols[k]], X[!,:choice_x]) .+ 0.000000001)
    end
    per_subject = combine(groupby(X[!, [:sid, ll_cols...]], :sid), ll_cols .=> sum)
    if length(model.shares) < model.K
        println("something is off!")
        println(model.shares)
        println(model)
    end
    mixed = exp.(Matrix(per_subject[!, 2:end])) * model.shares
    mixed[mixed .<= 0] .= eps()
    return sum(log.(mixed))
end

# Mean over `sid`s of the best per-component accuracy: for every `sid` the accuracy of
# each component is computed and the maximum is kept.
# Side effect: adds/overwrites columns pred_k and accu_k in `X`.
function accuracy(X, model)
    pred_cols = [Symbol("pred_"*string(k)) for k in 1:model.K]
    accu_cols = [Symbol("accu_"*string(k)) for k in 1:model.K]
    for k in 1:model.K
        X[!, pred_cols[k]] = predict(X, model.types[k])
        X[!, accu_cols[k]] = accu(X[!, pred_cols[k]], X[!,:choice_x])
    end
    per_subject = combine(groupby(X[!, [:sid, accu_cols...]], :sid), accu_cols .=> mean)
    best = [maximum([row[j] for j in 2:(model.K + 1)]) for row in eachrow(per_subject)]
    return mean(best)
end

# Assigns every `sid` to the component with the highest summed log-likelihood and
# stores the component index in column `:type` of `X`.
# Side effect: also adds/overwrites columns pred_k and ll_k in `X`. Returns `X`.
function assign_type!(X, model)
    pred_cols = [Symbol("pred_"*string(k)) for k in 1:model.K]
    ll_cols = [Symbol("ll_"*string(k)) for k in 1:model.K]
    for k in 1:model.K
        X[!, pred_cols[k]] = predict(X, model.types[k])
        X[!, ll_cols[k]] = log.(calc_probs(X[!, pred_cols[k]], X[!,:choice_x]) .+ 0.000000001)
    end
    per_subject = combine(groupby(X[!, [:sid, ll_cols...]], :sid), ll_cols .=> sum)
    assignments = map(eachrow(per_subject)) do row
        row[:sid] => argmax([row[j] for j in 2:(model.K + 1)])
    end
    best_type = Dict(assignments)
    X[!, :type] = (sid -> best_type[sid]).(X.sid)
    return X
end

# Sum over components k of the log-likelihood of component k on the rows of `X` whose
# `:type` column equals k.
function per_type_perf(X, model)
    per_type = map(1:model.K) do k
        rows_k = @subset(X, :type .== k)
        perf(rows_k, model.types[k])
    end
    return sum(per_type)
end

#%%
# Returns a zero-argument factory that builds a fresh JuMP model.
# `opt_f = "Ipopt"` gives an Ipopt model; `opt_f = "NLopt"` gives an NLopt model using
# the :GN_ORIG_DIRECT_L algorithm. `time_limit` (seconds) is applied only when it is
# not `nothing`, so the default leaves the solver's own limit untouched.
# Throws `ArgumentError` for any other `opt_f`, rather than returning `nothing` and
# failing later at the call site.
function gen_opt_fun(;time_limit=nothing, opt_f="Ipopt")
    function Ipopt_f()
        model = Model(Ipopt.Optimizer)
        time_limit === nothing || set_time_limit_sec(model, time_limit)
        return model
    end
    function NLopt_f()
        model = Model(NLopt.Optimizer)
        set_optimizer_attribute(model, "algorithm", :GN_ORIG_DIRECT_L)
        time_limit === nothing || set_optimizer_attribute(model, "maxtime", time_limit)
        return model
    end

    if opt_f == "Ipopt"
        return Ipopt_f
    elseif opt_f == "NLopt"
        return NLopt_f
    end
    throw(ArgumentError("unknown opt_f \"$opt_f\"; expected \"Ipopt\" or \"NLopt\""))
end


# Fits all parameters of `model` by minimising the negative log-likelihood with Optim,
# inside the box given by `lower(model)` / `upper(model)`.
# `opt_f`: "SAMIN" uses simulated annealing directly; "NM" wraps NelderMead in Fminbox;
# anything else wraps LBFGS (backtracking line search) in Fminbox.
# The objective mutates `model` on every evaluation; the minimiser is written back at
# the end. Prints the Optim result and returns `model`.
function opt_full_model(X, model;opt_f="LBFGS", time_limit=60)
    x0 = get_params_for_opt(model)
    println("opt_full_model")
    neg_ll(x; model=model) = -perf(X, set_params_from_opt!(x, model))
    lo = lower(model)
    hi = upper(model)

    od = Optim.OnceDifferentiable(neg_ll, x0; autodiff=:forward)
    if opt_f == "SAMIN"
        annealer = SAMIN(verbosity=1, nt=50, ns=50, rt=0.97)
        result = Optim.optimize(od, lo, hi, x0, annealer, Optim.Options(time_limit=time_limit, iterations=2*10^6))
    else
        inner = opt_f == "NM" ? NelderMead() : LBFGS(linesearch=LineSearches.BackTracking())
        result = Optim.optimize(od, lo, hi, x0, Fminbox(inner), Optim.Options(time_limit=time_limit))
    end
    best_x = Optim.minimizer(result)
    println(result)
    model = set_params_from_opt!(best_x, model)
    return model
end

# Fits all parameters of `m` through JuMP, minimising the negative log-likelihood.
# `optimizer_fun` is a zero-argument factory returning a JuMP model (see `gen_opt_fun`).
# The objective mutates `m` on each evaluation; the solution is written back into `m`,
# which is returned.
function opt_full_model_Jump(X, m; optimizer_fun=gen_opt_fun())
    init_x = get_params_for_opt(m)
    println("opt_full_m")
    # Objective taking the parameters as separate scalar arguments, as required for a
    # function registered with JuMP.
    function wrap_f(x...; m=m)
        m = set_params_from_opt!([x...], m)
        return -perf(X, m)
    end
    lower_lim = lower(m)
    upper_lim = upper(m)

    model = optimizer_fun()
    # One decision variable per optimiser parameter, box-constrained.
    @variable(model, upper_lim[i] >= x_var[i=1:length(init_x)] >= lower_lim[i] )
    set_start_value.(x_var, init_x)
    # Index range of the K entries that precede the global parameters.
    # NOTE(review): for K > 1 these slots hold `inv_softmax(shares)` values (log share + 2),
    # not raw shares, so "sum <= 1" constrains the transformed values; for K == 1
    # `get_params_for_opt` emits no share slots and this range covers the last local
    # parameter instead. Confirm the constraint is intended.
    shares_start = length(init_x) - length(m.glob_params) - m.K +1
    shares_end = length(init_x) - length(m.glob_params)
    @constraint(model, sum(x_var[i] for i in shares_start:shares_end) <= 1)
    register(model, :wrap_perf, length(init_x), wrap_f; autodiff=true)
    @NLobjective(model, Min, wrap_perf(x_var...))
    JuMP.optimize!(model)
    opt_vals = value.(x_var)
    m = set_params_from_opt!(opt_vals, m)
    m
end


# Optimises only the mixture shares of `m` through JuMP, keeping the components fixed.
# The first K-1 shares are the decision variables; the last share is one minus their sum.
# The objective mutates `m.shares` on each evaluation; the solution is written back and
# `m` is returned.
# NOTE(review): with m.K == 1 there are zero decision variables — presumably this is
# only meant to be called for K >= 2; confirm.
function opt_shares(X, m; optimizer_fun=gen_opt_fun())
    init_x = m.shares[1:end-1]
    # Objective taking the free shares as separate scalar arguments.
    function wrap_f(x...; m=m)
        m.shares = [x..., 1-sum(x)]
        return -perf(X, m)
    end
    lower_lim = zeros(m.K-1)
    upper_lim = ones(m.K-1)

    model = optimizer_fun()
    # Each free share lies in [0, 1] and together they may not exceed one.
    @variable(model, upper_lim[i] >= x_var[i=1:length(init_x)] >= lower_lim[i] )
    @constraint(model, sum(x_var[i] for i in 1:(m.K-1)) <= 1)
    set_start_value.(x_var, init_x)
    register(model, :wrap_perf, length(init_x), wrap_f; autodiff=true)
    @NLobjective(model, Min, wrap_perf(x_var...))
    JuMP.optimize!(model)
    opt_vals = value.(x_var)
    m.shares = [opt_vals..., 1 - sum(opt_vals)]
    m
end




# Black-box optimisation (BlackBoxOptim) of the global parameters, scoring candidates
# with `per_type_perf`.
# NOTE(review): this function calls `get_glob_params` and `set_glob_params` (without
# `!`), neither of which is defined in this module; only `set_glob_params!(x, m)` exists
# and it has a different signature. It also assigns the helper's result to
# `model.glob_params`, which elsewhere holds parameter *names*. As written it is
# expected to fail with an UndefVarError when called — confirm intent before use.
function bb_opt_glob_params(X, model; subset="all")
    init_x = get_glob_params(model.glob_params; subset=subset)
    # Objective: negative per-type log-likelihood for a candidate `x`.
    function wrap_f(x; model=model)
        g = set_glob_params(x, model.glob_params; subset=subset)
        model.glob_params = g
        return -per_type_perf(X, model)
    end
    # Search box: the bounds of the full optimiser vector of `model`.
    lower_lim = lower(model)
    upper_lim = upper(model)

    bb_res = bboptimize(wrap_f; SearchRange = collect(zip(lower_lim, upper_lim)), TraceInterval=30)
    bb_opt_params = best_candidate(bb_res)
    model.glob_params = set_glob_params(bb_opt_params, model.glob_params; subset=subset)
    model
end


# Fits all parameters of `model` with BlackBoxOptim, minimising the negative
# log-likelihood inside the box given by `lower(model)` / `upper(model)`.
# `time_limit` is passed as `MaxTime` and `pop_size` as `PopulationSize` (the keyword
# was previously misspelled, so `pop_size` did not reach the documented option).
# `paralell` is accepted for compatibility but not used.
# The objective mutates `model` on each evaluation; the best candidate is written back
# at the end and `model` is returned.
function bb_opt_model(X, model; time_limit=60., paralell=false, pop_size=500)
    println("opt_full_model")
    function wrap_f(x; model=model)
        model = set_params_from_opt!(x, model)
        return -perf(X, model)
    end
    lower_lim = lower(model)
    upper_lim = upper(model)
    bb_res = bboptimize(wrap_f; SearchRange = collect(zip(lower_lim, upper_lim)), MaxTime = time_limit, TraceInterval=30, PopulationSize=pop_size)
    bb_opt_params = best_candidate(bb_res)
    model = set_params_from_opt!(bb_opt_params, model)
    return model
end

# Standard errors of the fitted parameters of `model_in`.
# The Hessian of the negative log-likelihood at the current optimiser vector is computed
# with ForwardDiff and inverted; standard errors are the square roots of the diagonal.
# Returns a DataFrame with columns param, β (estimate), SE, t_stat and type.
# Each objective evaluation works on its own deep copy of `model_in`, so `model_in` is
# not mutated; an extra, never-used deep copy that was made up front has been dropped.
# Rows are formed by zipping names, estimates and type indices, so the table is as long
# as the shortest of them.
function SEs(X, model_in)
    function f(x...; m=deepcopy(model_in))
        m = set_params_from_opt!([x...], m)
        return -perf(X, m)
    end
    opt_x = get_params_for_opt(model_in)
    hess = ForwardDiff.hessian(x -> f(x...), opt_x)
    var_cov_matrix = inv(hess)
    βs = opt_x
    ses = sqrt.(diag(var_cov_matrix))
    t_stats = βs ./ ses
    namn = param_numbered_names(model_in)
    # Component index per local parameter, followed by 1..K for the share slots.
    model_idxs = vcat([k.*ones(length(get_params(model_in.types[k]; ps=model_in.local_params[k]))) for k in 1:model_in.K]..., collect(1:model_in.K)...)
    res_vec = map(zip(namn, βs, ses, t_stats, model_idxs)) do (param, β, SE, t_stat, type)
        (;param, β, SE, t_stat, type)
    end
    return DataFrame(res_vec)
end


end
